Sometimes I get this error:
Core Computation: 5614.850s ( 6.6%)
JAX Overhead: 60470.898s ( 70.9%)
SVGD Algorithm: 8773.522s ( 10.3%)
NumPy Operations: 0.000s ( 0.0%)
Other: 9936.396s ( 11.6%)
JAX Overhead Breakdown:
vmap/pmap 15285.256s ( 25.3% of JAX, 17.9% total)
Other JAX 14738.447s ( 24.4% of JAX, 17.3% total)
Dispatch 12227.066s ( 20.2% of JAX, 14.3% total)
Callbacks/FFI 9097.319s ( 15.0% of JAX, 10.7% total)
Primitives 5937.572s ( 9.8% of JAX, 7.0% total)
Tracing 3057.867s ( 5.1% of JAX, 3.6% total)
Array ops 127.371s ( 0.2% of JAX, 0.1% total)
Top Computation Functions:
pmf_function 2808.155s ( 3.3%) [ 840000 calls]
_compute_pmf_from_ctypes 2806.695s ( 3.3%) [ 840000 calls]
Top JAX Dispatch Functions:
call_wrapped 3057.210s ( 3.6%) [ 2125 calls]
bind 3056.626s ( 3.6%) [ 4610 calls]
_true_bind 3056.620s ( 3.6%) [ 4610 calls]
bind_with_trace 3056.610s ( 3.6%) [ 4628 calls]
Top JAX vmap/pmap Functions:
vmap_f 3057.616s ( 3.6%) [ 2000 calls]
_batch_outer 3057.158s ( 3.6%) [ 2000 calls]
_batch_inner 3057.107s ( 3.6%) [ 2000 calls]
flatten_fun_for_vmap 3056.901s ( 3.6%) [ 2000 calls]
_pjit_batcher 3056.473s ( 3.6%) [ 2000 calls]
Top JAX Array Functions:
device_put 127.371s ( 0.1%) [ 840001 calls]
Top JAX Primitive Functions:
process_primitive 3056.528s ( 3.6%) [ 2100 calls]
process_primitive 2881.044s ( 3.4%) [ 2000 calls]
Top JAX Callback/FFI Functions:
_wrapped_callback 3037.308s ( 3.6%) [ 840000 calls]
_callback 3030.319s ( 3.6%) [ 840000 calls]
pure_callback_impl 3029.693s ( 3.6%) [ 840000 calls]
Top SVGD Functions:
svgd_step 3055.878s ( 3.6%) [ 2000 calls]
fit 2858.822s ( 3.4%) [ 3 calls]
run_svgd 2858.821s ( 3.4%) [ 3 calls]
Optimization Recommendations:
• HIGH vmap/pmap overhead (17.9%):
  - vmap overhead is normal for vectorized operations
  - Consider using explicit loops if vmap is over small batches
  - Check if batch size can be increased
================================================================================
:::
:::
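Many of the per-function figures in this report cluster at about 3,057 s (`call_wrapped`, `bind`, `bind_with_trace`, `vmap_f`, `_batch_outer`, `_pjit_batcher`, `process_primitive`, `svgd_step`, and the "Tracing" category), which looks like the same elapsed time being counted once per nested frame. Assuming the report adds up cumulative (`cumtime`) values from Python's `cProfile`, the category totals (roughly 85,000 s in all) can far exceed the real runtime, and percentages built on them are distorted. The small experiment below (the functions `outer`, `middle` and `inner` are illustrative only) shows the effect: summing cumulative times counts the same work once per nesting level, while summing own-times (`tottime`) does not.

```python
import cProfile
import pstats


def inner() -> int:
    # Busy loop so the profiler records measurable own-time (tottime) here
    total = 0
    for i in range(200_000):
        total += i
    return total


def middle() -> int:
    return inner()  # almost no own-time; its cumtime includes inner()


def outer() -> int:
    return middle()  # almost no own-time; its cumtime includes middle() and inner()


profiler = cProfile.Profile()
profiler.enable()
outer()
profiler.disable()

# Stats.stats maps (filename, lineno, funcname) ->
# (primitive calls, total calls, tottime, cumtime, callers)
raw = pstats.Stats(profiler).stats
names = {"outer", "middle", "inner"}
tottime = {key[2]: val[2] for key, val in raw.items() if key[2] in names}
cumtime = {key[2]: val[3] for key, val in raw.items() if key[2] in names}

summed_cum = sum(cumtime.values())
summed_tot = sum(tottime.values())

print(f"outer cumtime (real runtime): {cumtime['outer']:.4f}s")
print(f"sum of cumtimes:              {summed_cum:.4f}s  (about 3x: nested frames counted repeatedly)")
print(f"sum of tottimes:              {summed_tot:.4f}s  (matches the real runtime)")
```

If the profiler behind this report works the same way, comparing `tottime`-style figures would give a more faithful split between JAX overhead and the actual `pmf_function` work.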
::: {#f1857753 .cell execution_count=7}
``` {.python .cell-code}
svgd.plot_convergence();
```
:::
(<Figure size 700x300 with 2 Axes>,
array([<Axes: title={'center': 'Mean Convergence'}, xlabel='SVGD Iteration', ylabel='Posterior Mean'>,
<Axes: title={'center': 'Std Convergence'}, xlabel='SVGD Iteration', ylabel='Posterior Std'>],
dtype=object))
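The two panels ("Mean Convergence" and "Std Convergence") track how the per-parameter mean and standard deviation of the particle cloud settle over SVGD iterations. A minimal NumPy sketch of those curves follows, assuming the history is a stack of particle snapshots, each shaped `(n_particles, dim)`. The real layout of `results['history']` may differ, and the contracting particle cloud below is synthetic.

```python
import numpy as np

rng = np.random.default_rng(0)
n_iter, n_particles, dim = 1000, 60, 3  # sizes match the SVGD summary of this run

# Synthetic stand-in for a particle history, shape (n_iter, n_particles, dim):
# particles start spread out and contract toward fixed values over iterations
target = np.array([2.0, 3.0, 2.0])
spread = np.exp(-np.arange(n_iter) / 200.0)[:, None, None]
history = target + spread * rng.normal(size=(n_iter, n_particles, dim))

# Mean and std across particles at each iteration, one curve per parameter: (n_iter, dim)
mean_curve = history.mean(axis=1)
std_curve = history.std(axis=1)

print("mean at first/last iteration:", mean_curve[0].round(3), mean_curve[-1].round(3))
print("std  at first/last iteration:", std_curve[0].round(3), std_curve[-1].round(3))
```

Plotting `mean_curve` and `std_curve` against the iteration index gives two panels of this kind; flat tails suggest the particle cloud has stopped moving.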
_build c_api numpy.css
_extensions custom-dark.scss numpy.theme
_freeze custom.scss objects.txt
_inv galleries pages
_quarto.yml index.qmd r_api
api logo.png styles.css
autodoc.mustache numpy-dark.theme
banner.png numpy-navbar-sidebar.css
/Users/kmt/PtDAlgorithms/.pixi/envs/default/lib/python3.13/pty.py:95: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.
pid, fd = os.forkpty()
fig, axes = svgd.plot_trace()
fig.savefig('../../galleries/examples/images/svgd_convergence.webp', format='webp')
======================================================================
SVGD Inference Summary
======================================================================
Number of particles: 60
Number of iterations: 1000
Parameter dimension: 3
Posterior estimates:
θ_0: 136.4124 ± 128.4685
95% CI: [-29.4463, 352.3441]
θ_1: 1633.8782 ± 986.7968
95% CI: [-744.5340, 3198.3957]
θ_2: 91.2288 ± 96.7373
95% CI: [-107.4353, 276.1241]
======================================================================
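The reported intervals are not mean ± 1.96·std: for θ_0 that would give about [-115.4, 388.2] rather than [-29.4463, 352.3441]. They are therefore presumably empirical percentiles of the 60 particles. The sketch below computes both kinds of interval with NumPy, assuming the particles are an array of shape `(n_particles, dim)` such as `results['particles']`. The lognormal particles are synthetic.

```python
import numpy as np

rng = np.random.default_rng(1)
# Synthetic stand-in for the fitted particles: 60 particles x 3 parameters
particles = rng.lognormal(mean=4.0, sigma=0.8, size=(60, 3))

mean = particles.mean(axis=0)
std = particles.std(axis=0)

# Symmetric normal-approximation interval: mean +/- 1.96 * std
normal_ci = np.stack([mean - 1.96 * std, mean + 1.96 * std])

# Empirical 95% interval from the 2.5th and 97.5th percentiles of the particles
lo, hi = np.percentile(particles, [2.5, 97.5], axis=0)

for i in range(particles.shape[1]):
    print(f"θ_{i}: {mean[i]:.4f} ± {std[i]:.4f}")
    print(f"   normal CI:     [{normal_ci[0, i]:.4f}, {normal_ci[1, i]:.4f}]")
    print(f"   percentile CI: [{lo[i]:.4f}, {hi[i]:.4f}]")
```

With skewed particle clouds the two intervals can differ a lot. The percentile version always stays within the range of the particles, while the symmetric one can extend past it, for example below zero.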
svgd.plot_pairwise(true_theta=true_theta,
    # param_names=['jump', 'flood_left', 'flood_right'],
    show_transformed=True,
);
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
Cell In[8], line 1
----> 1 svgd.plot_pairwise(true_theta=true_theta,
2 # param_names=['jump', 'flood_left', 'flood_right'],
3 show_transformed=True,
4 );
File ~/PtDAlgorithms/src/ptdalgorithms/svgd.py:1794, in plot_pairwise(self, true_theta, param_names, figsize, save_path, show_transformed)
1773 def plot_pairwise(self, true_theta=None, param_names=None,
1774 figsize=None, save_path=None, show_transformed=True):
1775 """
1776 Plot pairwise scatter plots for all parameter pairs.
1777
1778 Parameters
1779 ----------
1780 true_theta : array_like, optional
1781 True parameter values (if known) to overlay on plot
1782 param_names : list of str, optional
1783 Names for each parameter dimension
1784 figsize : tuple, optional
1785 Figure size (width, height)
1786 save_path : str, optional
1787 Path to save the plot
1788 show_transformed : bool, default=True
1789 If True, show transformed (constrained) parameter values.
1790 If False, show untransformed (unconstrained) values.
1791 Only relevant when using parameter transformations.
1792
1793 Returns
-> 1794 -------
1795 fig, axes
1796 Matplotlib figure and axes objects
1797 """
1798 if not self.is_fitted:
1799 raise RuntimeError("Must call fit() before plotting")
RuntimeError: Must call fit() before plotting
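The check that raises here (source lines 1798 and 1799 of `svgd.py`) looks only at the instance's `is_fitted` attribute. The minimal stand-in class below is hypothetical and is not the ptdalgorithms implementation. It mirrors that guard to show that the error depends solely on the flag of the object being plotted. Any path that leaves `is_fitted` False on that instance triggers it, for example a new instance created by re-running a constructor cell, regardless of results stored elsewhere.

```python
class ToySVGD:
    """Hypothetical stand-in mirroring the is_fitted guard seen in the traceback."""

    def __init__(self) -> None:
        self.is_fitted = False
        self.particles = None

    def fit(self) -> "ToySVGD":
        self.particles = [[0.0, 0.0, 0.0]]  # placeholder result
        self.is_fitted = True
        return self

    def plot_pairwise(self) -> str:
        if not self.is_fitted:
            raise RuntimeError("Must call fit() before plotting")
        return "plotted"


fitted = ToySVGD().fit()
print(fitted.plot_pairwise())  # fitted instance passes the guard

fresh = ToySVGD()  # e.g. the constructor cell was re-run
caught = None
try:
    fresh.plot_pairwise()
except RuntimeError as err:
    caught = str(err)
    print("RuntimeError:", err)

# Defensive pattern: fit on demand before plotting
if not fresh.is_fitted:
    fresh.fit()
print(fresh.plot_pairwise())
```

Printing `svgd.is_fitted` right before the failing call is a quick way to see which state the real object is in.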
anim = svgd.animate_pairwise(
    true_theta=[2.0, 3.0, 2.0],
    param_names=['jump', 'flood_left', 'flood_right'],
    thin=20,
    show_transformed=True,
)
anim  # Display in Jupyter
[INFO] Animation.save using <class 'matplotlib.animation.HTMLWriter'>
[INFO] figure size in inches has been adjusted from 9.0 x 6.8999999999999995 to 9.0 x 6.9
results = svgd.get_results()
results.keys()
dict_keys(['particles', 'theta_mean', 'theta_std', 'history'])